from dataclasses import dataclass, field
from enum import IntEnum
import logging
import re
import os
from typing import Any, Dict, List, Optional, Tuple
from ofrak.model.resource_model import ResourceIndexedAttribute
from ofrak.resource import Resource
from ofrak.core.dtb import DtbNode, DtbProperty
from ofrak.core import Addressable
from ofrak.core.filesystem import FilesystemEntry
from ofrak.service.resource_service_i import ResourceAttributeValueFilter, ResourceFilter
from ofrak_type import Range
from black import format_str, FileMode


LOGGER = logging.getLogger(__name__)


class SelectableAttributesError(Exception):
    """
    Raised when a resource cannot be uniquely selected from its parent, either through one of its
    indexed attributes or through its data range.
    """


class ActionType(IntEnum):
    """
    Category attached to every action recorded in a script.
    """

    UNPACK = 0
    MOD = 1
    PACK = 2
    # No specific category; used for generated variable assignments, which are always emitted
    # into the script regardless of the action type being filtered on.
    UNDEF = 3


@dataclass(frozen=True)
class ScriptAction:
    """
    A single, immutable entry of a script: the source code for one action together with the
    category that action belongs to.
    """

    # Category of the action (see `ActionType`)
    action_type: ActionType
    # Source code of the action, as it will appear in the generated script
    action: str


@dataclass
class ScriptSession:
    """
    State of one generated script: the actions already committed to it, the actions still staged
    in the queue, and the variable names given to resources (committed and staged), plus the
    boilerplate text placed before and after the actions.
    """

    # Staged actions that are not part of the script yet
    actions_queue: List[ScriptAction] = field(default_factory=list)
    # Committed actions, in the order they appear in the script
    actions: List[ScriptAction] = field(default_factory=list)
    # Committed mapping of resource ID -> variable name
    resource_variable_names: Dict[bytes, str] = field(default_factory=dict)
    # Staged mapping of resource ID -> variable name
    resource_variable_names_queue: Dict[bytes, str] = field(default_factory=dict)

    boilerplate_header: str = r"""
    from ofrak import *
    from ofrak.core import *
    from ofrak.gui.script_builder import get_child_by_range

    async def main(ofrak_context: OFRAKContext, root_resource: Optional[Resource] = None):"""
    # TODO: Replace with backend in use by OFRAK instance used to create the script.
    boilerplate_footer: str = r"""


    if __name__ == "__main__":
        ofrak = OFRAK()
        if False:
            import ofrak_angr
            import ofrak_capstone

            ofrak.discover(ofrak_capstone)
            ofrak.discover(ofrak_angr)

        if False:
            import ofrak_binary_ninja
            import ofrak_capstone

            ofrak.discover(ofrak_capstone)
            ofrak.discover(ofrak_binary_ninja)

        if False:  # older Ghidra backend with Java server
            import ofrak_ghidra

            ofrak.discover(ofrak_ghidra)

        if False:  # newer PyGhidra backend
            import ofrak_pyghidra

            ofrak.discover(ofrak_pyghidra)

        ofrak.run(main)
    """

    def get_var_name(self, id: bytes) -> str:
        """
        Look up the variable name assigned to a resource, preferring the committed mapping over
        the staged one.

        :param id: ID of the resource whose variable name is requested

        :raises ValueError: if neither mapping holds a name for the resource ID

        :return: the variable name assigned to the resource
        """
        for names in (self.resource_variable_names, self.resource_variable_names_queue):
            if id in names:
                return names[id]
        raise ValueError(f"No variable name for resource ID 0x{id.hex()}")


class ScriptBuilder:
    """
    Builds and maintains runnable OFRAK scripts as sequences of actions, with each script tied to
    a session.
    """

    def __init__(self):
        # Indexed attributes tried, in this order, when looking for a value that singles out a
        # resource among its siblings.
        selectable: List[ResourceIndexedAttribute] = [
            FilesystemEntry.Name,
            Addressable.VirtualAddress,
            DtbNode.DtbNodeName,
            DtbProperty.DtbPropertyName,
        ]
        # Memoized mapping of resource ID -> root resource
        self.root_cache: Dict[bytes, Resource] = {}
        # One session per root resource, keyed by the root resource's ID
        self.script_sessions: Dict[bytes, ScriptSession] = {}
        self.selectable_indexes: List[ResourceIndexedAttribute] = selectable

    async def add_action(
        self,
        resource: Resource,
        action: str,
        action_type: ActionType,
    ) -> None:
        """
        Stages an action in the queue of the script session the resource belongs to. The action
        is a code template in which `{resource}` is substituted with the variable name generated
        for the resource.

        Actions are only staged here so that actions which end up raising inside OFRAK never
        reach the final script. After the corresponding OFRAK calls have run, the caller is
        responsible for calling `commit_to_script` on success or `clear_script_queue` on failure.

        :param resource: Resource upon which the action is being taken
        :param action: Code template for the action, formatted with `resource=<variable name>`
        :param action_type: An instance of `ActionType` categorizing the action
        """
        resource_variable = await self._add_variable(resource)
        await self._add_action_to_session_queue(
            resource,
            action.format(resource=resource_variable),
            action_type,
        )

    async def get_script(self, resource: Resource) -> List[str]:
        """
        Renders the committed script of the session the resource belongs to.

        :param resource: Resource belonging to the session for which the script is to be returned

        :return: List of strings where each entry is a line in the script
        """
        session_root = await self._get_root_resource(resource)
        session_key = session_root.get_id()
        return self._get_script(session_key)

    async def commit_to_script(self, resource: Resource) -> None:
        """
        Moves the staged actions and variable names of a session into the script proper, then
        empties the queue. Meant to be called after one or more staged actions ran successfully.

        :param resource: Resource belonging to the session whose queue will be committed
        """
        session_root = await self._get_root_resource(resource)
        session = self._get_session(session_root.get_id())
        # Staged names override committed ones with the same resource ID
        session.resource_variable_names.update(session.resource_variable_names_queue)
        session.actions.extend(session.actions_queue)
        session.resource_variable_names_queue = {}
        session.actions_queue = []

    async def clear_script_queue(self, resource: Resource) -> None:
        """
        Discards everything staged in a session's queue (actions and variable names) without
        touching what was already committed. Meant to be called after a staged action failed.

        :param resource: Resource belonging to the session whose queue will be cleared
        """
        session_root = await self._get_root_resource(resource)
        session = self._get_session(session_root.get_id())
        session.resource_variable_names_queue = {}
        session.actions_queue = []

    async def _add_variable(self, resource: Resource) -> str:
        """
        Replaces references to a particular resource selected in the GUI with a generated variable
        name based on uniquely identifying characteristics of the resource. This overcomes the issue
        of referencing the same resource across OFRAK contexts due to the randomly generated
        resource IDs changing.

        Variables for ancestors that do not have one yet are generated first, so that the
        selector for this resource can refer to its parent's variable.

        :param resource: Resource that needs to be uniquely identified in the script

        :return: a unique variable name; may be an empty string (or a name that was never staged)
            if an unexpected error occurred while generating the selector or the name
        """
        if await self._var_exists(resource):
            return await self._get_variable_from_session(resource)

        root_resource = await self._get_root_resource(resource)
        if resource.get_id() == root_resource.get_id():
            await self._add_variable_to_session_queue(resource, "root_resource")
            return "root_resource"

        parent = await resource.get_parent()
        if not await self._var_exists(parent):
            await self._add_variable(parent)

        name = ""
        # Cannot propagate exceptions to the server as this would interfere with user actions
        # regardless of whether they're interested in the script. Currently only _get_selector()
        # and _generate_name() can lead to exceptions raised within ScriptBuilder.
        try:
            selector = await self._get_selector(resource)
            name = await self._generate_name(resource)
            await self._add_action_to_session_queue(
                resource,
                rf"""
        {name} = {selector}""",
                ActionType.UNDEF,
            )
            await self._add_variable_to_session_queue(resource, name)
        except SelectableAttributesError as e:
            # The resource cannot be selected: stage a placeholder that makes the script fail
            # loudly at that point instead of silently operating on the wrong resource.
            name = await self._generate_missing_name(resource, e)
            LOGGER.exception("Could not find selectable attributes for resource")
        except Exception:
            # Deliberately `Exception` rather than a bare `except`: task cancellation
            # (asyncio.CancelledError), KeyboardInterrupt and SystemExit must keep propagating.
            LOGGER.exception("Exception raised in add_variable")
        return name

    async def _generate_missing_name(self, resource: Resource, e: Exception) -> str:
        """
        Stages a placeholder variable for a resource that cannot be selected from its parent. The
        staged code raises a `RuntimeError` carrying the message of `e`, so that running the
        generated script stops at the point where the resource would have been needed.

        :param resource: Resource for which no selector could be generated
        :param e: Exception describing why the resource could not be selected

        :return: the placeholder variable name, unique among committed and staged names
        """
        root_resource = await self._get_root_resource(resource)
        session = self._get_session(root_resource.get_id())
        parent = await resource.get_parent()
        parent_name = await self._get_variable_from_session(parent)
        var_names = list(session.resource_variable_names.values()) + list(
            session.resource_variable_names_queue.values()
        )
        name = self._increment_missing_name_index(
            f"{parent_name}_MISSING_RESOURCE_0", var_names, 1
        )
        # The message is emitted with `!r` so that it is always a valid Python string literal in
        # the generated script, even when it contains quotes, backslashes or newlines (messages
        # for several colliding attributes are joined with newlines).
        await self._add_action_to_session_queue(
            resource,
            f"""
        # Resource with parent {parent_name} is missing, could not find selectable attributes.
        raise RuntimeError({str(e)!r})
        {name} = None""",
            ActionType.UNDEF,
        )
        await self._add_variable_to_session_queue(resource, name)
        return name

    def _increment_missing_name_index(self, name: str, var_names: List[str], index: int) -> str:
        name_template = "_".join(name.split("_")[:-1])
        if name in var_names:
            name = name_template + f"_{index}"
            index += 1
            return self._increment_missing_name_index(name, var_names, index)
        else:
            return name

    async def _get_root_resource(self, resource: Resource) -> Resource:
        """
        Returns the root resource of the given resource, memoizing the answer per resource ID.
        The last of the resource's ancestors is taken as its root; a resource without ancestors
        is its own root.
        """
        key = resource.get_id()
        try:
            return self.root_cache[key]
        except KeyError:
            pass
        lineage = list(await resource.get_ancestors())
        if lineage:
            root = lineage[-1]
        else:
            root = resource
        self.root_cache[key] = root
        return root

    async def _get_variable_from_session(self, resource: Resource) -> str:
        """
        Returns the variable name (committed or staged) that the resource's session holds for it.
        """
        session_root = await self._get_root_resource(resource)
        owning_session = self._get_session(session_root.get_id())
        resource_id = resource.get_id()
        return owning_session.get_var_name(resource_id)

    async def _var_exists(self, resource: Resource) -> bool:
        """
        Tells whether the resource already has a variable name, committed or staged, in its
        session.
        """
        session_root = await self._get_root_resource(resource)
        session = self._get_session(session_root.get_id())
        if resource.get_id() in session.resource_variable_names:
            return True
        return resource.get_id() in session.resource_variable_names_queue

    async def _add_action_to_session_queue(
        self, resource: Resource, action: str, action_type: ActionType
    ) -> None:
        """
        Appends an action to the staging queue of the session the resource belongs to.
        """
        session_root = await self._get_root_resource(resource)
        staged_action = ScriptAction(action_type, action)
        self._get_session(session_root.get_id()).actions_queue.append(staged_action)

    async def _add_variable_to_session_queue(self, resource: Resource, var_name: str) -> None:
        """
        Stages a variable name for the resource in the session the resource belongs to.
        """
        session_root = await self._get_root_resource(resource)
        staged_names = self._get_session(session_root.get_id()).resource_variable_names_queue
        staged_names[resource.get_id()] = var_name

    async def _test_selectable_attributes(
        self,
        ancestor: Resource,
        resource: Resource,
        attribute: ResourceIndexedAttribute,
        attribute_value: Any,
    ) -> None:
        """
        Checks that filtering the children of `ancestor` on the given attribute value yields
        exactly one child, i.e. that the `get_only_child` call emitted into the script will work.

        :param ancestor: Resource whose children are searched
        :param resource: Resource being selected; only used for the error message
        :param attribute: Indexed attribute to filter on
        :param attribute_value: Value the attribute must have

        :raises SelectableAttributesError: if `get_only_child` fails for any reason; the original
            exception is chained as the cause
        """
        try:
            await ancestor.get_only_child(
                r_filter=ResourceFilter(
                    attribute_filters=[
                        ResourceAttributeValueFilter(attribute=attribute, value=attribute_value)
                    ],
                )
            )
        except Exception as e:
            raise SelectableAttributesError(
                f"Resource with ID 0x{resource.get_id().hex()} cannot be uniquely identified by attribute "
                f"{attribute.__name__} (resource has value {attribute_value})."
            ) from e

    async def _test_get_child_by_range(
        self,
        ancestor: Resource,
        resource: Resource,
        range_in_parent: Range,
    ):
        """
        Checks that looking up the child of `ancestor` at `range_in_parent` gives back `resource`.

        :raises SelectableAttributesError: if the lookup fails or returns a different resource
        """
        candidate = await get_child_by_range(ancestor, range_in_parent)
        if candidate.get_id() == resource.get_id():
            return
        raise SelectableAttributesError(
            f"Resource with ID 0x{resource.get_id().hex()} cannot be uniquely identified by "
            f"data range: {range_in_parent}"
        )

    async def _get_selector(self, resource: Resource) -> str:
        """
        Builds the source code expression with which the generated script retrieves `resource`
        from its parent's variable.

        The preferred form is a `get_only_child` call filtered on a selectable indexed attribute.
        When no attribute uniquely identifies the resource, the expression falls back to
        `get_child_by_range` with the resource's data range within its parent.

        :param resource: Resource for which a selector expression is needed

        :raises SelectableAttributesError: if the data range fallback cannot identify the
            resource either

        :return: the selector expression as a string
        """
        root_resource = await self._get_root_resource(resource)
        session = self._get_session(root_resource.get_id())
        parent = await resource.get_parent()
        try:
            attribute, attribute_value = await self._get_selectable_attribute(resource)
            # Verify the filter really selects a single child before emitting it into the script
            await self._test_selectable_attributes(parent, resource, attribute, attribute_value)
            # The value is embedded as its repr so that it reads as a literal in the script
            attribute_value = f"{attribute_value!r}"
            return f"""await {session.get_var_name(parent.get_id())}.get_only_child(
                        r_filter=ResourceFilter(
                            tags={resource.get_most_specific_tags()},
                            attribute_filters=[
                                ResourceAttributeValueFilter(
                                    attribute={attribute.__name__}, 
                                    value={attribute_value}
                                )
                            ]   
                        )
                    )"""
        except SelectableAttributesError:
            # No usable attribute: identify the resource by where its data sits in the parent
            range_in_parent = await resource.get_data_range_within_parent()
            await self._test_get_child_by_range(parent, resource, range_in_parent)
            return f"""await get_child_by_range(
                        {session.get_var_name(parent.get_id())}, 
                        Range({range_in_parent.start}, {range_in_parent.end})
                    )
            """

    async def _get_self_and_siblings_matching_attribute(
        self, resource: Resource, attribute: ResourceIndexedAttribute, attribute_value: Any
    ) -> List[Resource]:
        """
        Returns the children of the resource's parent whose `attribute` equals `attribute_value`.
        """
        value_filter = ResourceAttributeValueFilter(attribute=attribute, value=attribute_value)
        parent = await resource.get_parent()
        matching_children = await parent.get_children(
            r_filter=ResourceFilter(
                attribute_filters=[value_filter],
            )
        )
        return list(matching_children)

    async def _get_selectable_attribute(
        self, resource: Resource
    ) -> Tuple[ResourceIndexedAttribute, Any]:
        """
        Finds the first attribute in `selectable_indexes` whose value on `resource` is matched by
        exactly one child of the resource's parent.

        :param resource: Resource for which a uniquely identifying attribute is wanted

        :raises SelectableAttributesError: if the resource has none of the selectable attributes,
            or if every candidate value is shared with at least one sibling

        :return: the attribute and the value the resource has for it
        """
        # Attribute name -> value, for candidates that matched more than one child
        collisions: Dict[str, Any] = {}
        for attribute in self.selectable_indexes:
            owner = attribute.attributes_owner
            if not owner or not resource.has_attributes(owner):
                continue
            value = attribute.get_value(resource.get_model())
            matches = await self._get_self_and_siblings_matching_attribute(
                resource, attribute, value
            )
            match_count = len(matches)
            if match_count == 1:
                return attribute, value
            if match_count > 1:
                collisions[attribute.__name__] = value

        if not collisions:
            raise SelectableAttributesError(
                f"Resource with ID 0x{resource.get_id().hex()} does not have a selectable attribute."
            )
        raise SelectableAttributesError(
            "\n".join(
                f"Resource with ID 0x{resource.get_id().hex()} cannot be uniquely identified by attribute {collision} (resource has value {value})."
                for collision, value in collisions.items()
            )
        )

    async def _generate_name(self, resource: Resource) -> str:
        """
        Generates a variable name for the resource from the lowercased name of its first most
        specific tag and an identifying value: the value of its selectable attribute (integers
        are rendered in hex) or, when it has none, the start of its data range within its parent.
        Characters other than ASCII letters and digits are replaced with underscores.

        If the name is already taken by a committed or a staged variable, it is prefixed with the
        variable name of the resource's parent.

        :param resource: Resource for which a variable name is needed

        :return: the generated variable name
        """
        root_resource = await self._get_root_resource(resource)
        session = self._get_session(root_resource.get_id())
        most_specific_tag = list(resource.get_most_specific_tags())[0].__name__.lower()
        try:
            _, identifying_value = await self._get_selectable_attribute(resource)
            if isinstance(identifying_value, int):
                identifying_value = hex(identifying_value)
        except SelectableAttributesError:
            range_in_parent = await resource.get_data_range_within_parent()
            identifying_value = hex(range_in_parent.start)
        name = re.sub(r"[^a-zA-Z0-9]", "_", f"{most_specific_tag}_{identifying_value}")

        # Staged names count as taken too: several variables can be staged before a commit
        # (e.g. ancestors of the resource), and reusing one of their names would rebind it in
        # the generated script.
        taken_names = set(session.resource_variable_names.values())
        taken_names.update(session.resource_variable_names_queue.values())
        if name in taken_names:
            parent = await resource.get_parent()
            return f"{session.get_var_name(parent.get_id())}_{name}"
        return name

    def _get_session(self, resource_id: bytes) -> ScriptSession:
        session = self.script_sessions.get(resource_id, None)
        if session is None:
            session = ScriptSession()
            self.script_sessions[resource_id] = session

        return session

    def _get_script(
        self, resource_id: bytes, target_type: Optional[ActionType] = None
    ) -> List[str]:
        """
        Renders the committed actions of a session as a script: boilerplate header, the selected
        actions, boilerplate footer. The text is dedented and then formatted with black; if
        formatting fails the unformatted text is returned instead.

        The session is looked up through `_get_session`, so asking for the script of a root
        resource that has no session yet yields the bare boilerplate rather than a `KeyError`.

        :param resource_id: ID of the root resource identifying the session
        :param target_type: If given, only actions of this type are included, in addition to
            `ActionType.UNDEF` actions, which are always included

        :return: List of strings where each entry is a line in the script
        """
        session = self._get_session(resource_id)
        script_list: List[str] = [session.boilerplate_header]
        # Always include UNDEF actions like variable assignments
        script_list.extend(
            script_action.action
            for script_action in session.actions
            if target_type is None
            or script_action.action_type == target_type
            or script_action.action_type == ActionType.UNDEF
        )
        script_list.append(session.boilerplate_footer)
        script_str = self._dedent("\n".join(script_list))
        try:
            return format_str(script_str, mode=FileMode()).split("\n")
        except Exception:
            # Best effort: an unformatted script is still more useful than none. The traceback
            # is logged once, on the module logger.
            LOGGER.exception("Black Formatting Error:")
            return script_str.split("\n")

    def _dedent(self, s: str) -> str:
        split = list(s.splitlines())
        prefix = os.path.commonprefix([line for line in split if line])
        indent_matches = re.findall(r"^\s+", prefix)
        if not indent_matches:
            return s
        indent_end_index = len(indent_matches[0])
        dedented_strings = [
            line[indent_end_index:] if line and line.startswith(prefix) else "" for line in split
        ]
        return "\n".join(dedented_strings)


async def get_child_by_range(resource: Resource, _range: Range) -> Resource:
    """
    Helper function returning the one child of `resource` whose data range within its parent
    equals `_range`.

    :param resource: Resource whose children are searched
    :param _range: Data range, relative to the parent, that the child must occupy

    :raises SelectableAttributesError: If no child maps to the _range in the resource, or if
    multiple children map to the same range.

    :return: the single child occupying `_range`
    """
    match = None
    for candidate in await resource.get_children():
        try:
            candidate_range = await candidate.get_data_range_within_parent()
        except ValueError:
            # For some reason in the script builder some children do not have data_ids.
            continue
        if not candidate_range == _range:
            continue
        if match:
            # A second child with the same range makes the selection ambiguous: the script
            #  builder cannot reasonably identify which child is being selected, so this is
            #  reported as an error. This is intended behavior, see:
            #  tests/unit/test_ofrak_server.py::test_selectable_attr_err.
            match = None
            break
        match = candidate
    if match:
        return match
    raise SelectableAttributesError(
        f"Resource with ID 0x{resource.get_id().hex()} "
        f"does not have a selectable attribute."
    )
